import models.resnet_models as models
import structured_my
from thop import profile
import strategy_snip
import logging
import torch.nn as nn
import torch
import numpy as np

def prune_model_18(model, args, train_loader):
    """Structurally prune a four-head ResNet in place and return it.

    The network must expose eleven residual stages and four classifiers:
    ``layer1`` (feeds ``layer2_1`` and ``layer2_2``), ``layer2_1`` (feeds
    ``layer3_1``/``layer3_2``), ``layer2_2`` (feeds ``layer3_3``/``layer3_4``),
    ``layer3_k`` (feeds ``layer4_k``) and ``layer4_k`` (feeds ``middle_fck``)
    for k = 1..4.  Each stage is pruned with its own rate; the rates are
    ordered ``layer1, layer2_1, layer2_2, layer3_1..layer3_4,
    layer4_1..layer4_4``.

    When ``args.blockprobs`` is ``None`` the rates are derived from one
    backward pass (``run_one_batch_4head``): inside a head, a stage whose
    accumulated ``|grad|`` per parameter is lower receives a larger share of
    that head's pruning budget ``args.headprobs[k] * head_parameter_count``.
    Stages shared with an earlier head keep the amount already assigned to
    them and that amount is subtracted from the later head's budget.  Rates
    above 0.9 are capped at 0.9.

    After pruning, parameter/MAC counts of a reference architecture pruned
    with each head's rates are logged via ``thop.profile``.

    Args:
        model: multi-head network with the attributes listed above.
        args: namespace; this function reads ``score`` ('snip' or 'random'),
            ``blockprobs``, ``headprobs``, ``dataset`` and ``arch``
            (``run_one_batch_4head`` reads further fields).
        train_loader: batch iterable, only consumed (through
            ``run_one_batch_4head``) when ``args.blockprobs`` is ``None``.

    Returns:
        ``model`` moved back to the device its first parameter was on when
        this function was called.

    Raises:
        NotImplementedError: if ``args.score`` is neither 'snip' nor 'random'.
    """
    device = next(model.parameters()).device
    block_name = model.layer1[0]._get_name()
    if args.score == 'snip':
        strategy = strategy_snip.SNIPStrategy()
    elif args.score == 'random':
        strategy = strategy_snip.RandomStrategy()
    else:
        raise NotImplementedError('unsupported args.score: %r' % (args.score,))

    # BasicBlock carries conv1/conv2 (+bn1/bn2); every other block type is
    # handled as a bottleneck with conv1..conv3 (+bn1..bn3).
    convs_per_block = 2 if block_name == 'BasicBlock' else 3

    # (stage, first block has a downsample branch, stages consuming this
    #  stage's output, classifier consuming this stage's output)
    stage_specs = (
        ('layer1', False, ('layer2_1', 'layer2_2'), None),
        ('layer2_1', True, ('layer3_1', 'layer3_2'), None),
        ('layer2_2', True, ('layer3_3', 'layer3_4'), None),
        ('layer3_1', True, ('layer4_1',), None),
        ('layer3_2', True, ('layer4_2',), None),
        ('layer3_3', True, ('layer4_3',), None),
        ('layer3_4', True, ('layer4_4',), None),
        ('layer4_1', True, (), 'middle_fc1'),
        ('layer4_2', True, (), 'middle_fc2'),
        ('layer4_3', True, (), 'middle_fc3'),
        ('layer4_4', True, (), 'middle_fc4'),
    )
    # Wiring of the reference network used only for the FLOPs report: one
    # chain layer1 -> layer2_1 -> layer3_1 -> layer4_1 -> middle_fc1.
    reference_specs = (
        ('layer1', False, ('layer2_1',), None),
        ('layer2_1', True, ('layer3_1',), None),
        ('layer3_1', True, ('layer4_1',), None),
        ('layer4_1', True, (), 'middle_fc1'),
    )
    # Position of a stage in this list == its index in the prune-rate list.
    stage_names = [spec[0] for spec in stage_specs]
    # Stages traversed by each head, from input to classifier.
    head_paths = (
        ('layer1', 'layer2_1', 'layer3_1', 'layer4_1'),
        ('layer1', 'layer2_1', 'layer3_2', 'layer4_2'),
        ('layer1', 'layer2_2', 'layer3_3', 'layer4_3'),
        ('layer1', 'layer2_2', 'layer3_4', 'layer4_4'),
    )
    max_rate = 0.9

    def prune_block(block, rate, with_downsample=False):
        """Prune every conv/bn pair of one residual block; return the
        indices selected for the block's last conv."""
        idxs = None
        for k in range(1, convs_per_block + 1):
            conv = getattr(block, 'conv%d' % k)
            if k > 1:
                # Drop the inputs that the previous conv no longer produces
                # before this conv's own channels are selected.
                structured_my.prune_related_conv(conv, idxs)
            idxs = strategy(conv.weight, amount=rate)
            structured_my.prune_conv(conv, idxs)
            structured_my.prune_batchnorm(getattr(block, 'bn%d' % k), idxs)
        if with_downsample:
            # The shortcut is pruned with the same indices as the last conv.
            structured_my.prune_conv(block.downsample[0], idxs)
            structured_my.prune_batchnorm(block.downsample[1], idxs)
        return idxs

    def prune_stage(net, name, first_has_downsample, consumers, head_fc, rate):
        """Prune all blocks of stage ``name`` of ``net`` and propagate the
        final indices to whatever consumes the stage's output."""
        stage = getattr(net, name)
        idxs = prune_block(stage[0], rate, with_downsample=first_has_downsample)
        for i in range(1, len(stage)):
            structured_my.prune_related_conv(stage[i].conv1, idxs)
            idxs = prune_block(stage[i], rate)
        for consumer in consumers:
            entry_block = getattr(net, consumer)[0]
            structured_my.prune_related_conv(entry_block.conv1, idxs)
            structured_my.prune_related_conv(entry_block.downsample[0], idxs)
        if head_fc is not None:
            structured_my.prune_related_linear(getattr(net, head_fc), idxs)

    def stage_importance(stage):
        """Return (sum of |weight.grad|, parameter count) over the Conv2d and
        BatchNorm2d modules of ``stage``; gradients must already exist."""
        n = len(stage)
        block_scores = torch.zeros(n)
        block_params = torch.zeros(n)
        for i in range(n):
            for module in stage[i].modules():
                if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
                    grad_mass = torch.sum(module.weight.grad.abs().flatten(), dim=0)
                    block_params[i] = block_params[i] + sum(
                        [q.numel() for q in module.parameters()])
                    block_scores[i] = block_scores[i] + grad_mass
        return torch.sum(block_scores), torch.sum(block_params)

    def chain_add(values):
        """Left-to-right sum without an initial 0 (keeps tensor results
        identical to a hand-written ``a + b + c``)."""
        total = values[0]
        for value in values[1:]:
            total = total + value
        return total

    def estimate_prune_rates():
        """Derive one prune rate per stage from gradient magnitudes."""
        run_one_batch_4head(model, args, train_loader)

        imp, n_para = {}, {}
        for name in stage_names:
            imp[name], n_para[name] = stage_importance(getattr(model, name))
            logging.info([imp[name], n_para[name]])

        head_para = [chain_add([n_para[s] for s in path]) for path in head_paths]
        logging.info('para all')
        logging.info(head_para)
        probs = [args.headprobs[k] for k in range(len(head_paths))]

        # Parameters per unit of gradient mass: the larger, the less
        # important the stage and the more of the budget it absorbs.
        inv_density = {s: 1 / (imp[s] / n_para[s]) for s in stage_names}

        prune_num = {}
        for k, path in enumerate(head_paths):
            shared = [s for s in path if s in prune_num]
            free = [s for s in path if s not in prune_num]
            total_inv = chain_add([inv_density[s] for s in free])
            if not shared:
                # First head: nothing is allocated yet.  The product order
                # is the historical one so the rates stay bit-identical.
                for s in free:
                    prune_num[s] = inv_density[s] / total_inv * head_para[k] * probs[k]
            else:
                # Later heads: stages already handled by an earlier head keep
                # their amount, which is taken out of this head's budget.
                budget = head_para[k] * probs[k]
                for s in shared:
                    budget = budget - prune_num[s]
                for s in free:
                    prune_num[s] = inv_density[s] / total_inv * budget

        rates = []
        for name in stage_names:
            rate = prune_num[name] / n_para[name]
            if rate > max_rate:
                rate = max_rate
            rates.append(rate)
        return rates

    def log_flops_per_head(rates):
        """Log params/MACs of the reference network, unpruned and pruned
        with each head's four stage rates."""
        is_cifar = args.dataset == 'cifar100' or args.dataset == 'cifar10'
        # NOTE(review): cifar10 is also built with 100 classes here, as in
        # the previous implementation - confirm this is intended.
        num_classes = 100 if is_cifar else 1000
        side = 32 if is_cifar else 224

        def build_reference():
            # The constructor is looked up under args.arch minus its last
            # six characters.
            return models.__dict__[args.arch[:-6]](num_classes=num_classes)

        def report(tag, net):
            macs, params = profile(net.cpu(), inputs=(torch.randn(1, 3, side, side),))
            logging.info(tag)
            logging.info("Number of Parameters: %.3fM" % (params / 1e6))
            logging.info("Number of MACS: %.1fM FLOPS: %.1fM" % (macs / 1e6, 2 * macs / 1e6))

        net = build_reference()
        report('ori', net)
        for k, path in enumerate(head_paths):
            if k > 0:
                # Head 1 is measured on the instance that was just profiled
                # as 'ori'; the other heads start from a fresh network.
                net = build_reference()
            for spec, source_stage in zip(reference_specs, path):
                name, first_has_downsample, consumers, head_fc = spec
                prune_stage(net, name, first_has_downsample, consumers, head_fc,
                            rates[stage_names.index(source_stage)])
            report('Head%d' % (k + 1), net)

    if args.blockprobs is None:
        prune_rate = estimate_prune_rates()
    else:
        prune_rate = args.blockprobs

    logging.info(prune_rate)
    for i, (name, first_has_downsample, consumers, head_fc) in enumerate(stage_specs):
        prune_stage(model, name, first_has_downsample, consumers, head_fc, prune_rate[i])

    log_flops_per_head(prune_rate)
    logging.info(prune_rate)
    logging.info(args.headprobs)

    return model.to(device)



def run_one_batch_4head(model, args, train_loader):
    """Populate ``.grad`` on ``model`` with one forward/backward pass.

    The model is switched to train mode and moved to ``args.device``.  At
    most the first four samples of the first batch of ``train_loader`` are
    used.  The loss combines, for the four head outputs returned by
    ``model``, the cross-entropy against the labels and the distillation
    loss (``kd_loss_function``) against the temperature-softened average of
    the four heads, weighted by ``1 - args.alpha`` and ``args.alpha``.

    Args:
        model: network whose forward pass returns exactly four logits tensors.
        args: namespace providing ``device``, ``temperature`` and ``alpha``.
        train_loader: iterable yielding ``(input, target)`` batches.

    Returns:
        None.  The gradients are left on the model's parameters.
    """
    model.train().to(args.device)
    criterion = nn.CrossEntropyLoss()

    images, labels = next(iter(train_loader))
    images = images[0:4, :, :, :]
    labels = labels[0:4]
    labels = labels.squeeze().long().to(args.device)
    images = images.to(args.device)

    middle_output1, middle_output2, middle_output3, middle_output4 = model(images)
    head_logits = (middle_output1, middle_output2, middle_output3, middle_output4)

    ce_losses = [criterion(logits, labels) for logits in head_logits]

    # Soft target: softmax of the averaged head logits at the distillation
    # temperature, detached so no gradient flows through the ensemble.
    output_ensemble = (middle_output1 + middle_output2 + middle_output3 + middle_output4) / 4
    soft_target = torch.softmax(output_ensemble / args.temperature, dim=1).detach()

    kd_losses = [
        kd_loss_function(logits, soft_target, args) * (args.temperature**2)
        for logits in head_logits
    ]

    total_loss = (1 - args.alpha) * (ce_losses[0] + ce_losses[1] + ce_losses[2] + ce_losses[3])/2 + \
                args.alpha * (kd_losses[0] + kd_losses[1] + kd_losses[2] + kd_losses[3])/2

    # Clear stale gradients so .grad reflects this single batch only.
    model.zero_grad()
    total_loss.backward()





def kd_loss_function(output, target_output,args):
    """Compute the knowledge-distillation loss of one head.

    Args:
        output: logits of the head; they are divided by ``args.temperature``
            here before the log-softmax over dim 1.
        target_output: soft targets that the caller has already divided by
            the temperature and passed through softmax.
        args: namespace providing ``temperature``.

    Returns:
        Scalar tensor: the negated mean of ``sum(log_softmax * target)``
        taken over dim 1.
    """
    softened_log_probs = torch.log_softmax(output / args.temperature, dim=1)
    per_sample = torch.sum(softened_log_probs * target_output, dim=1)
    return -torch.mean(per_sample)

def feature_loss_function(fea, target_fea):
    """Return the summed squared difference between two feature tensors,
    counting only positions where at least one of the two is positive.

    Args:
        fea: feature tensor.
        target_fea: tensor to compare against; combined element-wise with ``fea``.

    Returns:
        Scalar tensor holding the masked sum.
    """
    # 1.0 where either feature is > 0, 0.0 elsewhere.
    active = ((fea > 0) | (target_fea > 0)).float()
    masked_sq_err = (fea - target_fea) ** 2 * active
    return masked_sq_err.abs().sum()